import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _pair
from torchvision.models.resnet import BasicBlock, resnet50

from custom_functional import compute_grad_mag


class GatedSpatialConv2d(_ConvNd):
    """2-D convolution whose input is re-weighted by a learned spatial gate.

    A small stack of 1x1 convolutions looks at the input features
    concatenated with a one-channel gating map and emits one attention value
    per pixel. The input is scaled by ``1 + attention`` and then passed
    through this layer's own convolution weights.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, dilation=1, groups=1, bias=False):
        """
        :param in_channels: channels of the feature map being gated
        :param out_channels: channels produced by the convolution
        :param kernel_size: int or pair, convolution kernel size
        :param stride: int or pair, convolution stride
        :param padding: int or pair, zero padding
        :param dilation: int or pair, kernel dilation
        :param groups: number of convolution groups
        :param bias: whether the convolution has a bias term
        """
        transposed = False
        output_padding = _pair(0)
        super().__init__(
            in_channels, out_channels,
            _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation),
            transposed, output_padding, groups, bias, 'zeros')

        # The gate sees the features plus the single gating channel.
        gate_width = in_channels + 1
        gate_layers = [
            nn.BatchNorm2d(gate_width),
            nn.Conv2d(gate_width, gate_width, 1),
            nn.ReLU(True),
            nn.Conv2d(gate_width, 1, 1),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        ]
        self.gate_conv = nn.Sequential(*gate_layers)

    def forward(self, input_features, gating_features):
        """
        :param input_features:  [NxCxHxW] features to be gated
        :param gating_features: [Nx1xHxW] map that drives the gate
        :return: convolution of the gated features
        """
        gate_input = torch.cat([input_features, gating_features], dim=1)
        alphas = self.gate_conv(gate_input)
        # Residual-style gating: alphas lie in (0, 1), so the scale is in (1, 2).
        gated = input_features * (alphas + 1)
        return F.conv2d(gated, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def reset_parameters(self):
        """Xavier-normal init for the conv weight, zeros for the bias (if any)."""
        nn.init.xavier_normal_(self.weight)
        if self.bias is None:
            return
        nn.init.zeros_(self.bias)


def initialize_weights(*models):
    """Re-initialise the parameters of every given model in place.

    ``nn.Conv2d`` and ``nn.Linear`` weights get Kaiming-normal initialisation
    and their biases (when present) are zeroed. ``nn.BatchNorm2d`` scale is
    set to 1 and shift to 0; batch-norm layers built with ``affine=False``
    have no such parameters and are left untouched.

    :param models: any number of ``nn.Module`` objects; all sub-modules are visited
    """
    for model in models:
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                # weight/bias are None when the layer was created with affine=False
                if module.weight is not None:
                    nn.init.ones_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)


class Crop(nn.Module):
    """Crop a tensor to the size of a reference tensor along trailing axes."""

    def __init__(self, axis, offset):
        """
        :param axis: first axis to crop; every axis from here to the last is cropped
        :param offset: start index of the kept window on each cropped axis
        """
        super(Crop, self).__init__()
        self.axis = axis
        self.offset = offset

    def forward(self, x, ref):
        """
        :param x: input layer
        :param ref: reference usually data in; its size on each cropped axis
                    gives the length of the window kept from ``x``
        :return: ``x`` restricted to ``[offset, offset + ref.size(axis))`` on
                 every axis from ``self.axis`` onwards
        """
        for axis in range(self.axis, x.dim()):
            ref_size = ref.size(axis)
            # Build the index directly as int64 on x's device. Routing it
            # through x's own dtype would wrap or round the indices for
            # uint8 / half inputs.
            indices = torch.arange(self.offset, self.offset + ref_size,
                                   dtype=torch.long, device=x.device)
            x = x.index_select(axis, indices)
        return x


class MyIdentity(nn.Module):
    """Drop-in stand-in for ``Crop`` that performs no cropping at all."""

    def __init__(self, axis, offset):
        """
        :param axis: stored only; not used by ``forward``
        :param offset: stored only; not used by ``forward``
        """
        super().__init__()
        self.axis, self.offset = axis, offset

    def forward(self, x, ref):
        """Return ``x`` unchanged; ``ref`` is ignored."""
        return x


class SideOutputCrop(nn.Module):
    """
    Side-output head: a 1x1 convolution down to one channel, optionally
    followed by a ConvTranspose2d upsampling and a crop to a reference size.
    """

    def __init__(self, num_output, kernel_sz=None, stride=None, upconv_pad=0, do_crops=True):
        """
        :param num_output: number of channels of the incoming feature map
        :param kernel_sz: kernel size of the transposed conv; ``None`` disables upsampling
        :param stride: stride of the transposed conv
        :param upconv_pad: padding of the transposed conv
        :param do_crops: crop the upsampled map to the reference when True,
                         otherwise pass it through unchanged
        """
        super().__init__()
        self._do_crops = do_crops
        self.conv = nn.Conv2d(num_output, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)

        self.upsample = kernel_sz is not None
        if self.upsample:
            self.upsampled = nn.ConvTranspose2d(
                1, out_channels=1, kernel_size=kernel_sz, stride=stride,
                padding=upconv_pad, bias=False)
            if self._do_crops:
                self.crops = Crop(2, offset=kernel_sz // 4)
            else:
                self.crops = MyIdentity(None, None)

    def forward(self, res, reference=None):
        """
        :param res: feature map with ``num_output`` channels
        :param reference: tensor whose spatial size the upsampled output is cropped to
        :return: one-channel side output
        """
        side_output = self.conv(res)
        if not self.upsample:
            return side_output

        upsampled = self.upsampled(side_output)
        return self.crops(upsampled, reference)


class ASPP(nn.Module):
    """Atrous spatial pyramid pooling with extra image-level and edge branches.

    Branches, concatenated along the channel axis in this order:
    image-level pooled features, edge features, a 1x1 convolution, and one
    dilated 3x3 convolution per entry of ``rates``. A final 1x1 convolution
    fuses them back to ``reduction_dim`` channels.
    """

    def __init__(self, in_dim, reduction_dim=256, rates=(6, 12, 18)):
        """
        :param in_dim: channels of the incoming feature map
        :param reduction_dim: channels produced by every branch and by the output
        :param rates: dilation rates of the 3x3 branches (any length)
        """
        super(ASPP, self).__init__()
        branches = [
            nn.Sequential(
                nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
                nn.BatchNorm2d(reduction_dim),
                nn.ReLU(inplace=True))
        ]
        for r in rates:
            # padding == dilation keeps the spatial size for a 3x3 kernel
            branches.append(nn.Sequential(
                nn.Conv2d(in_dim, reduction_dim, kernel_size=3, dilation=r, padding=r, bias=False),
                nn.BatchNorm2d(reduction_dim),
                nn.ReLU(inplace=True)
            ))
        self.features = torch.nn.ModuleList(branches)

        # img level features
        self.img_pooling = nn.AdaptiveAvgPool2d(1)
        self.img_conv = nn.Sequential(
            nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
            nn.BatchNorm2d(reduction_dim), nn.ReLU(inplace=True)
        )
        self.edge_conv = nn.Sequential(
            nn.Conv2d(1, reduction_dim, kernel_size=1, bias=False),
            nn.BatchNorm2d(reduction_dim), nn.ReLU(inplace=True)
        )

        # Fused width follows the actual branch count (conv branches plus the
        # image-level and edge branches) instead of assuming three rates.
        num_branches = len(self.features) + 2
        self.out = nn.Sequential(
            nn.Conv2d(reduction_dim * num_branches, reduction_dim, 1),
            nn.BatchNorm2d(reduction_dim),
            nn.ReLU(True)
        )

    def forward(self, x, edge):
        """
        :param x: feature map, [N x in_dim x h x w]
        :param edge: one-channel edge map, [N x 1 x H x W]; resized to (h, w)
        :return: fused features, [N x reduction_dim x h x w]
        """
        spatial = x.size()[2:]

        # Global context: pool to 1x1, project, then broadcast back.
        img_features = self.img_pooling(x)
        img_features = self.img_conv(img_features)
        img_features = F.interpolate(img_features, spatial, mode='bilinear', align_corners=True)

        edge_features = F.interpolate(edge, spatial, mode='bilinear', align_corners=True)
        edge_features = self.edge_conv(edge_features)

        branch_outputs = [img_features, edge_features]
        branch_outputs.extend(f(x) for f in self.features)
        return self.out(torch.cat(branch_outputs, dim=1))


class GSCNN(nn.Module):
    """Two-stream segmentation network on a ResNet-50 backbone.

    The backbone produces multi-scale features. A parallel full-resolution
    edge stream refines the first-stage features with residual blocks and
    gated convolutions driven by one-channel side outputs of the backbone.
    The predicted edge map is fused with a Canny edge map of the input and
    fed, together with the deepest backbone features, into an ASPP module
    and a small decoder that yields the segmentation logits.
    """

    def __init__(self, num_classes, pretrained=False):
        """
        :param num_classes: number of channels of the segmentation output
        :param pretrained: forwarded to ``resnet50(pretrained=...)``
        """
        super(GSCNN, self).__init__()
        res = resnet50(pretrained=pretrained)
        self.layer0 = nn.Sequential(res.conv1, res.bn1, res.relu)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = res.layer1
        self.layer2 = res.layer2
        self.layer3 = res.layer3
        self.layer4 = res.layer4
        del res

        # One-channel side outputs of backbone stages 1-3; they drive the gates.
        self.dsn1 = nn.Conv2d(256, 1, 1)
        self.dsn2 = nn.Conv2d(512, 1, 1)
        self.dsn3 = nn.Conv2d(1024, 1, 1)

        # Edge stream: residual block + 1x1 conv halving the channel count.
        self.res1 = nn.Sequential(
            BasicBlock(64, 64),
            nn.Conv2d(64, 32, 1)
        )
        self.res2 = nn.Sequential(
            BasicBlock(32, 32),
            nn.Conv2d(32, 16, 1)
        )
        self.res3 = nn.Sequential(
            BasicBlock(16, 16),
            nn.Conv2d(16, 8, 1)
        )
        self.fuse = nn.Conv2d(8, 1, 1)

        # Mixes the predicted edge map with the Canny map (2 -> 1 channel).
        self.cw = nn.Conv2d(2, 1, 1)

        self.gate1 = GatedSpatialConv2d(32, 32)
        self.gate2 = GatedSpatialConv2d(16, 16)
        self.gate3 = GatedSpatialConv2d(8, 8)

        self.aspp = ASPP(2048)

        self.bot_fine = nn.Conv2d(256, 48, kernel_size=1, bias=False)

        self.final_seg = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, kernel_size=1, bias=False))

        self.sigmoid = nn.Sigmoid()
        initialize_weights(self.final_seg)

    @staticmethod
    def _canny_edges(inp):
        """Return a [N x 1 x H x W] float32 Canny edge map on ``inp``'s device."""
        n, _, h, w = inp.size()
        # detach() so inputs that require grad can still be exported to NumPy.
        # NOTE(review): the uint8 cast assumes pixel values in the 0..255
        # range; confirm against the data pipeline.
        im_arr = inp.detach().cpu().numpy().transpose((0, 2, 3, 1)).astype(np.uint8)
        canny = np.zeros((n, 1, h, w), dtype=np.float32)
        for i, image in enumerate(im_arr):
            canny[i, 0] = cv2.Canny(np.ascontiguousarray(image), 10, 100)
        # Follow the input's device rather than assuming CUDA.
        return torch.from_numpy(canny).to(inp.device)

    def forward(self, inp):
        """
        :param inp: image batch, [N x 3 x H x W]
        :return: ``[seg_out, edge_out]`` where ``seg_out`` is
                 [N x num_classes x H x W] logits and ``edge_out`` is the
                 sigmoid-activated [N x 1 x H x W] edge map
        """
        x_size = inp.size()
        out_hw = x_size[2:]

        canny = self._canny_edges(inp)

        # Backbone
        m1 = self.layer0(inp)
        m2 = self.layer1(self.maxpool(m1))
        m3 = self.layer2(m2)
        m4 = self.layer3(m3)
        m5 = self.layer4(m4)

        # Edge stream, run at input resolution.
        m1f = F.interpolate(m1, out_hw, mode='bilinear', align_corners=True)
        cs = self.res1(m1f)
        s1 = F.interpolate(self.dsn1(m2), out_hw, mode='bilinear', align_corners=True)
        cs = self.gate1(cs, s1)
        cs = self.res2(cs)
        s2 = F.interpolate(self.dsn2(m3), out_hw, mode='bilinear', align_corners=True)
        cs = self.gate2(cs, s2)
        cs = self.res3(cs)
        s3 = F.interpolate(self.dsn3(m4), out_hw, mode='bilinear', align_corners=True)
        cs = self.gate3(cs, s3)
        cs = self.fuse(cs)
        edge_out = self.sigmoid(cs)

        # Fuse learned edges with the Canny prior into one attention map.
        cat = torch.cat((edge_out, canny), dim=1)
        acts = self.cw(cat)
        acts = self.sigmoid(acts)

        # Decoder: ASPP on the deepest features, merged with stage-1 features.
        dec0_up = self.aspp(m5, acts)
        dec0_up = F.interpolate(dec0_up, m2.size()[2:], mode='bilinear', align_corners=True)
        dec0_fine = self.bot_fine(m2)
        dec0 = torch.cat([dec0_fine, dec0_up], 1)

        dec1 = self.final_seg(dec0)
        seg_out = F.interpolate(dec1, out_hw, mode='bilinear', align_corners=True)

        return [seg_out, edge_out]


def perturbate_input_(input, n_elements=200):
    """Set randomly chosen entries of a single-sample batch to 1, in place.

    ``n_elements`` channel, row and column indices are drawn independently
    (with replacement) from NumPy's global RNG; every position in the
    cartesian product of the three index sets is set to 1.

    :param input: array/tensor of shape [1 x C x H x W]; modified in place
    :param n_elements: number of indices drawn per axis
    :return: the same ``input`` object
    """
    N, C, H, W = input.shape
    assert N == 1
    # randint's upper bound is exclusive, so (0, C) covers the same range as
    # the deprecated random_integers(0, C - 1).
    c_ = np.random.randint(0, C, n_elements)
    h_ = np.random.randint(0, H, n_elements)
    w_ = np.random.randint(0, W, n_elements)
    if torch.is_tensor(input):
        c_, h_, w_ = (torch.as_tensor(idx, dtype=torch.long, device=input.device)
                      for idx in (c_, h_, w_))
    # Broadcasting the three index vectors against each other addresses the
    # full cartesian product in one write instead of n_elements**3
    # individual assignments.
    input[0, c_[:, None, None], h_[None, :, None], w_[None, None, :]] = 1
    return input


def _sample_gumbel(shape, eps=1e-10, device=None):
    """
    Sample from Gumbel(0, 1)

    based on
    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
    (MIT license)

    :param shape: shape of the returned tensor
    :param eps: small constant keeping both logarithms away from log(0)
    :param device: device for the sample; when ``None`` the sample lives on
                   CUDA if it is available and on the CPU otherwise
    :return: float tensor of Gumbel(0, 1) noise with the requested shape
    """
    if device is None:
        # Previously this was an unconditional .cuda(), which fails on
        # CPU-only hosts; fall back to the CPU there.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    U = torch.rand(shape, device=device)
    return - torch.log(eps - torch.log(U + eps))


def _gumbel_softmax_sample(logits, tau=1.0, eps=1e-10):
    """
    Draw a sample from the Gumbel-Softmax distribution

    based on
    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb
    (MIT license)

    :param logits: 3-D tensor of unnormalised scores; the softmax runs over dim 1
    :param tau: temperature; the noisy logits are divided by it before the softmax
    :param eps: numerical-stability constant forwarded to ``_sample_gumbel``
    :return: tensor shaped like ``logits`` whose entries sum to 1 along dim 1
    """
    assert logits.dim() == 3
    # Make sure the noise sits on the same device as the logits, whatever
    # device the sampler picked.
    gumbel_noise = _sample_gumbel(logits.size(), eps=eps).to(logits.device)
    y = logits + gumbel_noise
    return F.softmax(y / tau, dim=1)


def _one_hot_embedding(labels, num_classes):
    """Embedding labels to one-hot form.

    Args:
      labels: (LongTensor) class labels, sized [N, H, W], with values in
        ``[0, num_classes)``.
      num_classes: (int) number of classes.

    Returns:
      (tensor) float one-hot encoding, sized [N, #classes, H, W], on the same
      device as ``labels``.
    """
    # Rows of the identity matrix are the one-hot vectors; build it on the
    # labels' device instead of assuming CUDA.
    y = torch.eye(num_classes, device=labels.device)
    # Indexing yields [N, H, W, C]; move the class axis to position 1.
    return y[labels].permute(0, 3, 1, 2)


class DualTaskLoss(nn.Module):
    """Loss comparing boundaries of the predicted and ground-truth segmentations.

    A Gumbel-Softmax sample of the logits and the one-hot ground truth are
    both passed through ``compute_grad_mag``; the two results are compared
    with an L1 loss that is averaged only over positions where either
    result is at least a small threshold.
    """

    def __init__(self, num_classes, cuda=False):
        """
        :param num_classes: number of segmentation classes
        :param cuda: forwarded as the ``cuda`` argument of ``compute_grad_mag``
        """
        super(DualTaskLoss, self).__init__()
        self._cuda = cuda
        self.num_classes = num_classes

    def forward(self, input_logits, gts, ignore_pixel=255):
        """
        :param input_logits: NxCxHxW segmentation logits
        :param gts: NxHxW integer class labels
        :param ignore_pixel: label value excluded from the loss
        :return: final loss
        """
        N, C, H, W = input_logits.shape
        th = 1e-8  # 1e-10
        eps = 1e-10

        # Neutralise ignored pixels: zero logits for every class and label 0.
        ignore_mask = (gts == ignore_pixel).detach()
        input_logits = input_logits.masked_fill(ignore_mask.view(N, 1, H, W), 0)
        gt_semantic_masks = gts.detach().masked_fill(ignore_mask, 0)
        gt_semantic_masks = _one_hot_embedding(gt_semantic_masks, self.num_classes).detach()

        g = _gumbel_softmax_sample(input_logits.view(N, C, -1), tau=0.5)
        g = g.reshape((N, C, H, W))
        g = compute_grad_mag(g, cuda=self._cuda)

        g_hat = compute_grad_mag(gt_semantic_masks, cuda=self._cuda)

        g = g.view(N, -1)
        g_hat = g_hat.view(N, -1)
        loss_ewise = F.l1_loss(g, g_hat, reduction='none')

        # Average separately over positions where the prediction, resp. the
        # ground truth, reaches the threshold; eps guards an empty mask.
        p_plus_g_mask = (g >= th).detach().float()
        loss_p_plus_g = torch.sum(loss_ewise * p_plus_g_mask) / (torch.sum(p_plus_g_mask) + eps)

        p_plus_g_hat_mask = (g_hat >= th).detach().float()
        loss_p_plus_g_hat = torch.sum(loss_ewise * p_plus_g_hat_mask) / (torch.sum(p_plus_g_hat_mask) + eps)

        total_loss = 0.5 * loss_p_plus_g + 0.5 * loss_p_plus_g_hat

        return total_loss


class CrossEntropyLoss2d(nn.Module):
    """Pixel-wise cross-entropy: log-softmax over channels followed by NLL loss."""

    def __init__(self, weight=None, ignore_index=255):
        """
        :param weight: optional per-class rescaling weights
        :param ignore_index: target value that does not contribute to the loss
        """
        super(CrossEntropyLoss2d, self).__init__()
        # Keyword arguments on purpose: the second positional parameter of
        # the NLL loss constructors is the legacy `size_average`, so passing
        # ignore_index positionally left the ignore label inactive.
        self.nll_loss = nn.NLLLoss(weight=weight, ignore_index=ignore_index, reduction='mean')

    def forward(self, inputs, targets):
        """
        :param inputs: NxCxHxW logits
        :param targets: NxHxW integer class labels
        :return: scalar mean loss over the non-ignored pixels
        """
        return self.nll_loss(F.log_softmax(inputs, dim=1), targets)


class JointEdgeSegLoss(nn.Module):
    """Sum of segmentation cross-entropy, balanced edge BCE and dual-task loss."""

    def __init__(self, classes, edge_weight=1, seg_weight=1, att_weight=1, dual_weight=1):
        """
        :param classes: number of segmentation classes
        :param edge_weight: multiplier of the edge loss
        :param seg_weight: multiplier of the segmentation loss
        :param att_weight: stored on the module; not used by ``forward``
        :param dual_weight: multiplier of the dual-task loss
        """
        super(JointEdgeSegLoss, self).__init__()
        self.num_classes = classes
        # As constructed here the criterion holds no tensors, so it needs no
        # explicit device placement.
        self.seg_loss = CrossEntropyLoss2d()

        self.edge_weight = edge_weight
        self.seg_weight = seg_weight
        self.att_weight = att_weight
        self.dual_weight = dual_weight
        self.dual_task = DualTaskLoss(classes)

    def bce2d(self, input, target):
        """Class-balanced binary cross-entropy (logits form) over an edge map.

        Positives (target == 1) are weighted by the fraction of negatives and
        negatives (target == 0) by the fraction of positives; every other
        target value gets weight 0 and is thereby ignored.

        :param input: [B x C x H x W] edge scores
        :param target: [B x C x H x W] float edge labels
        :return: scalar mean of the weighted loss over all elements
        """
        # NOTE(review): GSCNN.forward returns an edge map that has already
        # been through a sigmoid, while this uses the logits form of BCE -
        # confirm that is intended.
        log_p = input.transpose(1, 2).transpose(2, 3).contiguous().view(1, -1)  # [1, B*H*W*C]
        target_t = target.transpose(1, 2).transpose(2, 3).contiguous().view(1, -1)  # [1, B*H*W*C]

        pos_index = (target_t == 1)
        neg_index = (target_t == 0)

        pos_num = int(pos_index.sum())
        neg_num = int(neg_index.sum())
        sum_num = pos_num + neg_num

        # Built on the input's device; starts at 0 so ignored pixels stay 0.
        weight = torch.zeros_like(log_p)
        if sum_num > 0:
            weight[pos_index] = neg_num / sum_num
            weight[neg_index] = pos_num / sum_num

        loss = F.binary_cross_entropy_with_logits(log_p, target_t, weight, reduction='mean')
        return loss

    def forward(self, inputs, targets):
        """
        :param inputs: pair ``(segin, edgein)`` of segmentation and edge predictions
        :param targets: pair ``(segmask, edgemask)`` of the matching ground truths
        :return: scalar total loss
        """
        segin, edgein = inputs
        segmask, edgemask = targets
        seg_loss = self.seg_weight * self.seg_loss(segin, segmask)
        # The edge term carries a fixed extra factor of 20.
        edge_loss = self.edge_weight * 20 * self.bce2d(edgein, edgemask)
        dual_loss = self.dual_weight * self.dual_task(segin, segmask)
        loss = seg_loss + edge_loss + dual_loss
        return loss.mean()


if __name__ == "__main__":
    # Smoke test: push a dummy batch through the network and print the shape
    # of the segmentation output. Model and input are placed on the same
    # device (CUDA when available) so the forward pass sees consistent tensors.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GSCNN(2).to(device)
    a = torch.zeros((2, 3, 512, 512), device=device)
    with torch.no_grad():  # shape check only; no autograd graph needed
        res = model(a)
    print(res[0].shape)
